from pathlib import Path
from typing import Literal, cast

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from matplotlib.lines import Line2D
from matplotlib.patches import Patch, Rectangle
from matplotlib.ticker import PercentFormatter
from sklearn.metrics import auc
from torch import Tensor

from deception_detection.experiment import Experiment
from deception_detection.log import logger
from deception_detection.metrics import get_auroc
from deception_detection.prompt_scorer import PromptScorer
from deception_detection.results.utils import (
    AMBIGUOUS_COLOR,
    DECEPTIVE_COLOR,
    HONEST_COLOR,
    NEUTRAL_COLOR,
    get_control_df,
    get_dataframe_of_scores_on_dataset,
    get_dataset_color,
    save_plot_exp,
    to_readable_dataset_name,
)
from deception_detection.scores import Scores
from deception_detection.types import Label


def get_fprs_tprs(honest: Tensor, deceptive: Tensor):
    """Compute ROC points treating ``deceptive`` as positives and ``honest`` as negatives.

    1000 evenly spaced thresholds are swept from one below the smallest score to the
    largest score; at each threshold the fraction of scores strictly above it is taken.

    Returns:
        ``(fprs, tprs)`` as two lists of floats, one entry per threshold. Both lists are
        empty when either input tensor has no elements.
    """
    # A ROC curve is undefined when one of the two classes is missing.
    if 0 in (honest.numel(), deceptive.numel()):
        return [], []

    lowest = min(honest.min().item(), deceptive.min().item())
    highest = max(honest.max().item(), deceptive.max().item())
    # Starting at ``lowest - 1`` guarantees the first threshold yields a rate of 1.0.
    cutoffs = torch.linspace(lowest - 1, highest, 1000)

    def rate_above(values: Tensor) -> list[float]:
        # Broadcast to (n_thresholds, n_values), then average over the values axis.
        exceeds = values > cutoffs[:, None]
        return exceeds.float().mean(dim=-1).tolist()

    return rate_above(honest), rate_above(deceptive)


def abbrev_dataset_name(dataset_name: str) -> str:
    """Map a dataset name to the abbreviated code used for x-axis labels.

    Matching is case-insensitive; ``_``, ``-`` and any run of whitespace are treated as
    a single space. Rules are tried in the order below and the first match wins:

    * gender + secret -> ``"GS"``
    * sleeper -> ``"ST"``
    * insider, or trading together with the ``XXXX-9`` marker -> ``"IT"``
    * instruct + decept, or "instruct dishonesty" -> ``"ID"``
    * harm / refusal -> ``"HP-C"`` (choice), ``"HP-KR"`` (knowledge / report), else ``"HP"``
    * convincing / convince -> ``"CG"``
    * alpaca, control, or the ``XXXX-9`` marker without trading -> ``"Control"``

    Args:
        dataset_name: Raw dataset name.

    Returns:
        The abbreviation, or ``str(dataset_name)`` unchanged when no rule matches.
    """
    # ``split()`` without arguments collapses separator runs of any length (and strips),
    # so names such as "instruct___dishonesty" still normalise to "instruct dishonesty".
    s_norm = " ".join(str(dataset_name).lower().replace("_", " ").replace("-", " ").split())

    # The marker has to be compared in its normalised form: after lower-casing and
    # separator replacement the literal "XXXX-9" can never occur in ``s_norm``.
    marker = "xxxx 9"
    has_marker = marker in s_norm
    has_trading = "trading" in s_norm

    # Gender Secret
    if "gender" in s_norm and "secret" in s_norm:
        return "GS"

    # Sleeper Agents
    if "sleeper" in s_norm:
        return "ST"

    # Insider Trading
    if "insider" in s_norm or (has_trading and has_marker):
        return "IT"

    # Instruct Dishonesty
    if ("instruct" in s_norm and "decept" in s_norm) or "instruct dishonesty" in s_norm:
        return "ID"

    # Harm-Refusal + Choice / Knowledge-Report
    if "harm" in s_norm or "refusal" in s_norm:
        if "choice" in s_norm:
            return "HP-C"
        if "knowledge" in s_norm or "report" in s_norm:
            return "HP-KR"
        return "HP"

    # Convince Me
    if "convincing" in s_norm or "convince" in s_norm:
        return "CG"

    # Control / Alpaca; the bare marker (without "trading") also counts as control.
    if "alpaca" in s_norm or "control" in s_norm or (has_marker and not has_trading):
        return "Control"

    return str(dataset_name)


def format_model_name_for_title(model_name: str) -> str:
    """Return the model name as it should appear in a plot title.

    Names of the form ``llama-<size>b-<3.3|3.1>`` (case-insensitive, surrounding
    whitespace ignored) are rewritten to ``llama-v<version>-<size>b-instruct``, e.g.
    ``'llama-70b-3.3' -> 'llama-v3.3-70b-instruct'``. Every other name is returned
    exactly as it was passed in.
    """
    normalized = str(model_name).strip().lower()
    if not normalized.startswith("llama"):
        return model_name

    pieces = normalized.split("-")
    if len(pieces) < 3:
        return model_name

    # Expected layout: pieces[1] is the size (e.g. "70b"), pieces[2] the version.
    size, version = pieces[1], pieces[2]
    if size.endswith("b") and version in ("3.3", "3.1"):
        return f"llama-v{version}-{size}-instruct"

    return model_name


# Fixed left-to-right order of the abbreviated dataset codes on the x-axis.
# "Control" is intentionally not listed here: order_datasets() always places it last.
ORDERED_DATASETS = ["CG", "GS", "HP-C", "HP-KR", "IT", "ID", "ST"]


def order_datasets(datasets: list[str]) -> list[str]:
    """Sort dataset codes into plotting order.

    Codes listed in ``ORDERED_DATASETS`` come first (in that list's order), followed by
    any other non-control codes in their input order, and finally every ``"Control"``
    entry.
    """
    controls: list[str] = []
    others: list[str] = []
    for code in datasets:
        (controls if code == "Control" else others).append(code)

    known = [code for code in ORDERED_DATASETS if code in others]
    unlisted = [code for code in others if code not in ORDERED_DATASETS]
    return [*known, *unlisted, *controls]


def violin_plot(
    experiment: Experiment,
    scorer: PromptScorer,
    include_train: bool = True,
    exclude_datasets: set[str] | None = None,
    legend_loc: str = "lower right",
    outfile: Path | None = None,
    figsize: tuple[float, float] | None = None,
):
    """Plot per-dataset violins of honest vs. deceptive scores plus a control violin.

    Raw (un-normalised) scores of all selected evaluation datasets and of the control
    data are min-max scaled together to the range [0, 100]. Each dataset gets one
    honest and one deceptive violin; the control data gets a single violin placed
    further right. A dashed red line marks the 99th percentile of the scaled control
    scores (labelled "1% FPR on control").

    Args:
        experiment: Source of evaluation and control scores, and of the model name
            used in the title.
        scorer: Scorer passed to the dataframe helpers; ``scorer.method`` is part of
            the output file name.
        include_train: When False, datasets whose name starts with
            ``experiment.cfg.train_data`` are left out.
        exclude_datasets: Dataset names to skip. ``None`` selects a built-in default
            pair of datasets.
        legend_loc: Currently unused; the legend is always anchored outside the axes
            at the upper left. Kept so existing callers keep working.
        outfile: Forwarded to ``save_plot_exp``.
        figsize: Figure size; defaults to ``(7, 4)``.
    """
    # Local names on purpose: these plot-specific colours must not shadow the
    # module-level HONEST_COLOR / DECEPTIVE_COLOR constants used by other plots.
    honest_color = "lightblue"
    deceptive_color = "lightcoral"
    control_color = "lightgray"

    if exclude_datasets is None:
        exclude_datasets = {"ai_liar__original_without_answers", "repe_honesty__plain_with_sys_val"}
    dataset_names = [
        name
        for name in experiment.get_eval_scores()
        if name not in exclude_datasets
        and (include_train or not name.startswith(experiment.cfg.train_data))
    ]

    # Raw scores for every selected dataset, labelled with the abbreviated code.
    dfs: list[pd.DataFrame] = []
    for dataset_name in dataset_names:
        df = get_dataframe_of_scores_on_dataset(
            experiment, dataset_name, scorer=scorer, normalize_scores=False
        )
        df["dataset"] = abbrev_dataset_name(dataset_name)
        dfs.append(df)

    eval_df = pd.concat(dfs, ignore_index=True) if dfs else pd.DataFrame()
    control_df = get_control_df(experiment, scorer, normalize=False)
    has_control = not control_df.empty

    # Joint min/max over evaluation and control scores drives the 0-100 scaling.
    all_scores: list[float] = []
    if not eval_df.empty:
        all_scores.extend(eval_df["score"].tolist())
    if has_control:
        all_scores.extend(control_df["score"].tolist())
    if all_scores:
        min_score = min(all_scores)
        max_score = max(all_scores)
    else:
        min_score, max_score = 0, 1

    def scale_to_100(score: float) -> float:
        # Degenerate range: put everything in the middle instead of dividing by zero.
        if max_score == min_score:
            return 50.0
        return ((score - min_score) / (max_score - min_score)) * 100

    if not eval_df.empty:
        eval_df["score"] = eval_df["score"].apply(scale_to_100)
    if has_control:
        control_df["score"] = control_df["score"].apply(scale_to_100)

    figsize = figsize or (7, 4)
    fig, ax = plt.subplots(1, 1, figsize=figsize)

    # No grid, no spines, white background.
    ax.grid(False)
    for side in ("top", "right", "left", "bottom"):
        ax.spines[side].set_visible(False)
    ax.set_facecolor("white")

    datasets = eval_df["dataset"].unique().tolist() if not eval_df.empty else []
    if has_control:
        datasets.append("Control")
    # order_datasets puts "Control" last, so non-control datasets occupy slots 0..k-1.
    datasets = order_datasets(datasets)

    def draw_violin(
        scores: list[float], position: float, width: float, color: str, alpha: float
    ) -> None:
        """Draw one body-only violin; do nothing when there are no scores."""
        if len(scores) == 0:
            return
        parts = ax.violinplot(
            [scores],
            positions=[position],
            widths=width,
            showmeans=False,
            showmedians=False,
            showextrema=False,
        )
        for body in parts["bodies"]:
            body.set_facecolor(color)
            body.set_alpha(alpha)

    # Datasets are spaced 2.0 apart; honest/deceptive sit 0.3 left/right of the slot
    # centre, while the control violin is shifted a full unit to the right.
    for i, dataset in enumerate(datasets):
        base_position = i * 2.0
        if dataset == "Control" and has_control:
            draw_violin(
                control_df["score"].values.tolist(), base_position + 1.0, 0.3, control_color, 1.0
            )
            continue

        dataset_data = eval_df[eval_df["dataset"] == dataset]
        for label, offset, color in (
            ("honest", -0.3, honest_color),
            ("deceptive", 0.3, deceptive_color),
        ):
            label_scores = dataset_data[dataset_data["label"] == label]["score"].values.tolist()
            draw_violin(label_scores, base_position + offset, 0.4, color, 0.7)

    # 99th percentile of the control scores, i.e. the 1% FPR threshold on control.
    if has_control:
        control_99th_percentile = np.percentile(control_df["score"].values, 99)
        ax.axhline(y=control_99th_percentile, color="red", linestyle="--", alpha=0.7, linewidth=1.5)

    # Fix the y-limits BEFORE drawing the separator boxes: the boxes are sized from
    # ax.get_ylim(), so they must see the final limits rather than the autoscaled ones.
    all_plot_scores: list[float] = []
    if not eval_df.empty:
        all_plot_scores.extend(eval_df["score"].tolist())
    if has_control:
        all_plot_scores.extend(control_df["score"].tolist())
    if all_plot_scores:
        score_min = min(all_plot_scores)
        score_max = max(all_plot_scores)
        padding = (score_max - score_min) * 0.1 if score_max > score_min else 5
        ax.set_ylim(max(0, score_min - padding), min(100, score_max + padding))
    else:
        ax.set_ylim(0, 100)  # fallback
    y_min, y_max = ax.get_ylim()

    # One box around all non-control datasets ...
    non_control_datasets = [d for d in datasets if d != "Control"]
    if non_control_datasets:
        leftmost_position = 0 * 2.0 - 0.7  # first slot, with padding
        rightmost_position = (len(non_control_datasets) - 1) * 2.0 + 0.7  # last slot, with padding
        big_box = Rectangle(
            (leftmost_position, y_min),
            rightmost_position - leftmost_position,
            y_max - y_min,
            linewidth=1,
            edgecolor="black",
            facecolor="none",
            alpha=0.3,
        )
        ax.add_patch(big_box)

    # ... and a separate, narrower box centred on the control violin.
    if has_control and "Control" in datasets:
        control_position = datasets.index("Control") * 2.0
        control_box = Rectangle(
            (control_position + 0.7, y_min),
            0.6,
            y_max - y_min,
            linewidth=1,
            edgecolor="black",
            facecolor="none",
            alpha=0.3,
        )
        ax.add_patch(control_box)

    ax.set_xlabel("Dataset", fontsize=12)
    ax.set_ylabel("Deceptiveness Score", fontsize=12)

    model_name = experiment.cfg.model_name.value
    formatted_model_name = format_model_name_for_title(model_name)
    ax.set_title(f"Self-Evaluation - {formatted_model_name}", fontsize=14, fontweight="bold", pad=15)

    # Ticks sit at the slot centre, except for control whose violin is shifted right.
    x_positions = [
        i * 2.0 + 1.0 if name == "Control" else i * 2.0 for i, name in enumerate(datasets)
    ]
    ax.set_xticks(x_positions)
    ax.set_xticklabels(datasets.copy(), rotation=45, ha="right", fontsize=10)

    legend_elements = [
        Patch(facecolor=honest_color, alpha=0.7, label="Honest"),
        Patch(facecolor=deceptive_color, alpha=0.7, label="Deceptive"),
    ]
    if has_control:
        legend_elements.append(Patch(facecolor=control_color, alpha=1.0, label="Control (honest)"))
        legend_elements.append(
            Line2D(
                [0],
                [0],
                color="red",
                linestyle="--",
                alpha=0.7,
                linewidth=1.5,
                label="1% FPR on control",
            )
        )

    # Legend is placed outside the axes, to the right.
    ax.legend(
        handles=legend_elements,
        title="Response Type",
        bbox_to_anchor=(1.05, 1),
        loc="upper left",
        fontsize=9,
    )

    # Make room for the outside legend.
    plt.tight_layout()

    save_plot_exp(experiment, f"violin_plot_{scorer.method}.png", outfile)


def plot_roc_comparison(
    experiment: Experiment,
    scorer: PromptScorer,
    exclude_datasets: set[str] | None = None,
    outfile: Path | None = None,
    figsize: tuple[float, float] | None = None,
    x_maxs: tuple[float, float] = (1, 0.03),
    x_line: float | None = 0.01,
    control_dataset_name: str = "alpaca__plain",
):
    """Plot ROC curves for every evaluation dataset in two side-by-side panels.

    The left panel uses each dataset's own honest scores as negatives, the right panel
    uses the control dataset's scores as negatives. Datasets without honest scores, or
    with empty honest/deceptive scores after processing, are skipped with a warning.

    Args:
        experiment: Source of evaluation and control scores.
        scorer: Used to process scores; ``scorer.method`` is part of the file name.
        exclude_datasets: Dataset names to skip.
        outfile: Forwarded to ``save_plot_exp``.
        figsize: Figure size; defaults to ``(8, 4)``.
        x_maxs: Upper x-limits of the left and right panel.
        x_line: x position of a vertical line on the right panel; ``None`` disables it.
        control_dataset_name: Key of the control dataset in the control scores.
    """
    # Load the data before creating the figure so a missing control dataset does not
    # leave an orphaned figure behind.
    all_scores = experiment.get_eval_scores()
    alpaca_scores = scorer.process(experiment.get_control_scores()[control_dataset_name])

    fig, axs = plt.subplots(1, 2, figsize=figsize or (8, 4), sharey=True)

    for dataset_name, scores in all_scores.items():
        if exclude_datasets is not None and dataset_name in exclude_datasets:
            continue
        split_scores = scores.get_scores_by_label()
        if len(split_scores[Label.HONEST]) == 0:
            logger.warning(f"no honest scores for {dataset_name}")
            continue
        honest = scorer.process(split_scores[Label.HONEST])
        deceptive = scorer.process(split_scores[Label.DECEPTIVE])
        # Skip if either tensor is empty after processing
        if honest.numel() == 0 or deceptive.numel() == 0:
            logger.warning(f"empty scores after processing for {dataset_name}")
            continue
        name = to_readable_dataset_name(dataset_name, separate_val=True)
        color = get_dataset_color(dataset_name)
        axs[0].plot(*get_fprs_tprs(honest, deceptive), label=name, color=color)
        axs[1].plot(*get_fprs_tprs(alpaca_scores, deceptive), label=name, color=color)

    axs[0].set_xlabel("FPR on Honest")
    axs[1].set_xlabel("FPR on Control")
    axs[0].set_ylabel("Recall (TPR)")
    axs[0].set_ylim(0, 1)  # shared with the right panel via sharey

    # Formatting that is identical for both panels apart from the x-limit.
    for ax, x_max in zip(axs, x_maxs):
        ax.set_xlim(0, x_max)
        ax.set_box_aspect(1)
        # Show one decimal on zoomed-in axes so ticks such as 0.5% stay distinguishable.
        ax.xaxis.set_major_formatter(
            PercentFormatter(xmax=1, decimals=1 if x_max < 0.05 else 0)
        )
        ax.yaxis.set_major_formatter(PercentFormatter(xmax=1, decimals=0))
        # Chance diagonal.
        ax.plot([0, 1], [0, 1], color="k", linestyle="--", linewidth=1)

    # Explicit None check: a caller passing 0.0 asked for a line at x=0.
    if x_line is not None:
        axs[1].axvline(x_line, color="k", linestyle="-", linewidth=1)

    axs[1].legend(loc="lower right")

    save_plot_exp(experiment, f"all_roc_plots_{scorer.method}.png", outfile)
    plt.close(fig)


def roleplaying_violin(
    exp: Experiment,
    scorer: PromptScorer,
    outfile: Path | None = None,
    figsize: tuple[float, float] = (5, 4),
):
    """Violin plot of ``roleplaying__plain`` scores per model-graded deception grade.

    Grades 1-7 each get a violin, with the control scores as an eighth "Control"
    violin. Also prints the AUROC obtained when grades <= 4 count as honest and
    grades > 4 as deceptive, and the recall of the deceptive group at the 99% quantile
    of the control scores.
    """
    roleplay_df = get_dataframe_of_scores_on_dataset(exp, "roleplaying__plain", scorer=scorer)
    control_df = get_control_df(exp, scorer)
    control_df["model_graded_deception"] = "Control"

    # Colour bands: grades 1-2 honest, 3-5 ambiguous, 6-7 deceptive, then control.
    palette = (
        [HONEST_COLOR, HONEST_COLOR]
        + [AMBIGUOUS_COLOR, AMBIGUOUS_COLOR, AMBIGUOUS_COLOR]
        + [DECEPTIVE_COLOR, DECEPTIVE_COLOR]
        + [NEUTRAL_COLOR]
    )
    grade_order = [1, 2, 3, 4, 5, 6, 7, "Control"]

    plt.figure(figsize=figsize)
    sns.violinplot(
        data=pd.concat([roleplay_df, control_df]),
        x="model_graded_deception",
        y="score",
        order=grade_order,
        hue_order=list(grade_order),
        palette=palette,
        hue="model_graded_deception",
        legend=False,
        cut=0,
    )
    plt.xlabel("Model Graded Deception")
    plt.ylabel("Probe Deceptiveness Score")
    plt.yticks([])

    # Binary split of the graded examples (control rows are not part of roleplay_df).
    grades = roleplay_df["model_graded_deception"]
    honest_scores = roleplay_df[grades <= 4]["score"]
    deceptive_scores = roleplay_df[grades > 4]["score"]
    auroc = get_auroc(honest_scores, deceptive_scores)  # type: ignore
    print(f"AUROC reclassifying ambiguous: {auroc:.3f}")

    threshold = np.quantile(control_df["score"], 0.99)  # type: ignore
    recall_at_threshold = (deceptive_scores > threshold).mean()
    print(f"Recall at threshold: {recall_at_threshold:.3f}")

    save_plot_exp(exp, f"roleplaying_{scorer.method}.png", outfile)


def ai_audit_violin(
    exp: Experiment,
    scorer: PromptScorer,
    outfile: Path | None = None,
    figsize: tuple[float, float] = (5, 2.5),
) -> None:
    """Violin plot of ``ai_audit__a1_2_reasoning`` scores next to the control scores.

    Draws one violin each for the Honest, Deceptive and Control labels and annotates
    the axes with the dataset's AUROC and its recall at 1% FPR as reported by
    ``scorer``.
    """
    dataset_name = "ai_audit__a1_2_reasoning"
    normalize = True

    audit_df = get_dataframe_of_scores_on_dataset(exp, dataset_name, scorer, normalize)
    # Title-case the labels so they match the hue order below.
    audit_df["label"] = audit_df["label"].map(lambda raw: raw.title())  # type: ignore
    control_df = get_control_df(exp, scorer)
    control_df["label"] = "Control"
    plot_df = pd.concat([audit_df, control_df])

    plt.figure(figsize=figsize)
    sns.violinplot(
        data=plot_df,
        y="score",
        hue="label",
        cut=0,
        palette=[HONEST_COLOR, DECEPTIVE_COLOR, NEUTRAL_COLOR],
        hue_order=["Honest", "Deceptive", "Control"],
        inner=None,
    )
    plt.yticks([])
    plt.xticks([])
    plt.ylabel("Probe deceptiveness score")
    plt.legend(title="", loc="upper left")

    audit_scores = exp.get_eval_scores()[dataset_name]
    auroc = scorer.auroc(audit_scores)
    recall_at_threshold = scorer.recall_at_fpr(audit_scores, 0.01)
    plt.axhline(y=0, color="k", linewidth=1)

    # Metric annotations in the lower-left corner, in axes coordinates.
    axes_coords = plt.gca().transAxes
    plt.text(0.05, 0.05, f"AUROC: {auroc:.3f}", transform=axes_coords)
    plt.text(0.05, 0.15, f"Recall @1%: {recall_at_threshold:.3f}", transform=axes_coords)
    plt.tight_layout()
    save_plot_exp(exp, f"ai_audit_violin_{scorer.method}.png", outfile)


def plot_roc_curve_with_auroc(
    experiment: Experiment,
    scorer: PromptScorer,
    compare_to: Literal["honest", "control"] = "honest",
    exclude_datasets: set[str] | None = None,
    x_max: float = 1.0,
    include_legend: bool = False,
    include_title: bool = False,
    legend_outside: bool = True,
    legend_with_auc: bool = True,
    figsize: tuple[float, float] = (3.5, 3.5),
    xline: float | None = 0.01,
    outfile: Path | None = None,
    control_dataset_name: str = "alpaca__plain",
) -> None:
    """Plot ROC curves for deceptive vs honest/control with AUROC values in legend.

    Datasets whose name starts with ``experiment.cfg.train_data`` are always skipped,
    as are datasets lacking deceptive scores (or honest scores when comparing to
    honest); a warning is logged for the latter.

    Args:
        experiment: The experiment containing the scores
        scorer: The scorer to process the scores
        compare_to: Whether to compare deceptive examples against honest or control examples
        exclude_datasets: Optional set of dataset names to exclude
        x_max: Upper limit of the x-axis; the chance diagonal is drawn up to it.
        include_legend: Whether to draw a legend with one entry per dataset.
        include_title: Whether to add a title.
        legend_outside: Place the legend to the right of the axes (the axes are
            narrowed to make room) instead of inside at the lower right.
        legend_with_auc: Append each dataset's AUROC to its legend entry.
        figsize: Figure size.
        xline: x position of a vertical reference line; ``None`` disables it.
        outfile: Forwarded to ``save_plot_exp``.
        control_dataset_name: Key of the control dataset, used when comparing to control.
    """
    # Load control scores before creating the figure so that a missing control dataset
    # does not leave an orphaned figure behind.
    control_scores = None
    if compare_to == "control":
        control_scores = scorer.process(experiment.get_control_scores()[control_dataset_name])

    plt.figure(figsize=figsize)
    ax = plt.gca()

    if legend_outside:
        # Move the main axis to make room for legend
        box = ax.get_position()
        ax.set_position((box.x0, box.y0, box.width * 0.8, box.height))

    exclude_datasets = exclude_datasets or set()
    legend_elements = []

    for dataset_name, scores in experiment.get_eval_scores().items():
        if dataset_name in exclude_datasets or dataset_name.startswith(experiment.cfg.train_data):
            continue

        split_scores = scores.get_scores_by_label()
        deceptive = scorer.process(split_scores[Label.DECEPTIVE])

        # Skip if no deceptive samples
        if deceptive.numel() == 0:
            logger.warning(f"no deceptive scores for {dataset_name}")
            continue

        if compare_to == "honest":
            if len(split_scores[Label.HONEST]) == 0:
                logger.warning(f"no honest scores for {dataset_name}")
                continue
            comparison_scores = scorer.process(split_scores[Label.HONEST])
            # Processing may itself produce an empty tensor, so check again.
            if comparison_scores.numel() == 0:
                logger.warning(f"no honest scores after processing for {dataset_name}")
                continue
        else:
            assert control_scores is not None
            comparison_scores = control_scores

        fprs, tprs = get_fprs_tprs(comparison_scores, deceptive)
        # Empty curves are only possible when the comparison scores are empty.
        if len(fprs) == 0 or len(tprs) == 0:
            logger.warning(f"empty FPR/TPR arrays for {dataset_name}, skipping")
            continue
        auroc = auc(fprs, tprs)

        name = to_readable_dataset_name(dataset_name)
        color = get_dataset_color(dataset_name)

        plt.plot(fprs, tprs, color=color)
        # A filled rectangle serves as the legend swatch instead of the line itself.
        rect = Rectangle((0, 0), 1, 1, fc=color)
        # Show a third decimal only near 1.0, where two decimals would round to 1.00.
        auroc_str = f"{auroc:.3f}" if auroc > 0.99 else f"{auroc:.2f}"
        if legend_with_auc:
            legend_elements.append((rect, f"{name} ({auroc_str})"))
        else:
            legend_elements.append((rect, name))

    reference_label = "Honest" if compare_to == "honest" else "Control"

    plt.plot([0, x_max], [0, x_max], color="k", linestyle="--", linewidth=1)
    plt.xlim(0, x_max)
    plt.ylim(0, 1)
    plt.gca().set_box_aspect(1)
    plt.xlabel(f"FPR on {reference_label} (%)")
    plt.ylabel("Recall (TPR) (%)")
    if include_title:
        plt.title(f"ROC Curves: Deceptive vs {reference_label}")

    # Explicit None check: a caller passing 0.0 asked for a line at x=0.
    if xline is not None:
        plt.axvline(xline, color="k", linestyle="-", linewidth=1)

    if include_legend:
        # Add legend to the right of the plot
        plt.legend(
            [elem[0] for elem in legend_elements],
            [elem[1] for elem in legend_elements],
            loc="upper left" if legend_outside else "lower right",
            bbox_to_anchor=(1.02, 1) if legend_outside else None,
            title="Eval. Dataset",
        )

    comparison_type = "honest" if compare_to == "honest" else "control"
    save_plot_exp(experiment, f"roc_with_auroc_{comparison_type}_{scorer.method}.png", outfile)
    plt.close()


def ai_liar_roc_plot(
    exp: Experiment,
    scorer: PromptScorer,
    outfile: Path | None = None,
    figsize: tuple[float, float] = (7, 3.5),
    label_text_fontsize: float = 10,
):
    """Plot two ROC curves for the ``ai_liar__original_without_answers`` dataset.

    The left panel uses the dataset's honest scores as negatives, the right panel uses
    the ``alpaca__plain`` control scores and is zoomed to FPR <= 3%. Each panel is
    annotated with its AUROC; the right panel also shows the recall at the threshold
    whose FPR is closest to 1%. Nothing is plotted (a warning is logged) when the
    dataset has no honest scores or its processed scores are empty.

    Args:
        exp: Source of evaluation and control scores.
        scorer: Used to process the scores.
        outfile: Forwarded to ``save_plot_exp``.
        figsize: Figure size.
        label_text_fontsize: Font size of the metric annotations.
    """
    dataset_name = "ai_liar__original_without_answers"
    alpaca_scores = scorer.process(exp.get_control_scores()["alpaca__plain"])
    split_scores = exp.get_eval_scores()[dataset_name].get_scores_by_label()

    # All validation and number crunching happens before the figure is created, so an
    # early return (or an exception) cannot leave an unclosed figure behind.
    if len(split_scores[Label.HONEST]) == 0:
        logger.warning(f"no honest scores for {dataset_name}")
        return
    honest = scorer.process(split_scores[Label.HONEST])
    deceptive = scorer.process(split_scores[Label.DECEPTIVE])
    # Skip if either tensor is empty after processing
    if honest.numel() == 0 or deceptive.numel() == 0:
        logger.warning(f"empty scores after processing for {dataset_name}")
        return

    # Each curve is computed once and reused for both the line and its annotations.
    honest_fprs, honest_tprs = get_fprs_tprs(honest, deceptive)
    control_fprs, control_tprs = get_fprs_tprs(alpaca_scores, deceptive)
    honest_auroc = auc(honest_fprs, honest_tprs)
    control_auroc = auc(control_fprs, control_tprs)
    # Recall at the sampled threshold whose control FPR is closest to 1%.
    fpr_at_01_idx = int(np.argmin(np.abs(np.array(control_fprs) - 0.01)))
    tpr_at_01 = control_tprs[fpr_at_01_idx]

    name = to_readable_dataset_name(dataset_name)
    color = get_dataset_color(dataset_name)

    fig, axs = plt.subplots(1, 2, figsize=figsize, sharey=True)
    axs[0].plot(honest_fprs, honest_tprs, label=name, color=color)
    axs[1].plot(control_fprs, control_tprs, label=name, color=color)

    axs[0].set_xlabel("FPR on Honest")
    axs[0].set_ylabel("Recall (TPR)")
    axs[0].set_xlim(0, 1)
    axs[0].set_ylim(0, 1)
    axs[0].set_box_aspect(1)
    axs[0].plot([0, 1], [0, 1], color="k", linestyle="--", linewidth=1)

    axs[1].set_xlabel("FPR on Alpaca")
    axs[1].set_box_aspect(1)
    axs[1].set_xlim(0, 0.03)
    axs[1].plot([0, 1], [0, 1], color="k", linestyle="--", linewidth=1)
    axs[1].axvline(0.01, color="k", linestyle="-", linewidth=1)
    axs[1].set_xticks([0, 0.01, 0.02, 0.03])

    # Metric annotations, right-aligned near the lower-right corner of each panel.
    text_args = dict(
        fontsize=label_text_fontsize,
        ha="right",
    )
    axs[0].text(
        0.95, 0.07, f"AUROC: {honest_auroc:.3f}", transform=axs[0].transAxes, **text_args
    )
    axs[1].text(0.95, 0.15, f"Recall @1%: {tpr_at_01:.3f}", transform=axs[1].transAxes, **text_args)
    axs[1].text(
        0.95, 0.07, f"AUROC: {control_auroc:.3f}", transform=axs[1].transAxes, **text_args
    )

    plt.tight_layout()
    save_plot_exp(exp, "ai_liar_roc_plot.png", outfile)
    plt.close(fig)


def ai_liar_violin(
    exp: Experiment,
    scorer: PromptScorer,
    outfile: Path | None = None,
    figsize: tuple[float, float] = (6, 4),
    show_yticks: bool = False,
    title: str = "Roleplaying prompts from Pacchiardi et al. (2023)",
):
    """Violin plot of ``ai_liar__original_without_answers`` scores by response category.

    Rows are categorised by their ``label`` and, for honest rows, their
    ``expected_label``. Three categories are drawn (honest/honest, honest/deceptive,
    deceptive); tick labels include the number of rows in each category.
    """
    df = get_dataframe_of_scores_on_dataset(exp, "ai_liar__original_without_answers", scorer)

    def get_category(row: pd.Series) -> str:  # type: ignore
        label = row["label"]
        if label == "honest":
            # ``expected_label`` is only consulted for honest responses.
            expected = row["expected_label"]
            if expected == "honest":
                return "honest, honest"
            if expected == "deceptive":
                return "honest, deceptive"
        elif label == "deceptive":
            return "deceptive"
        elif label in ("skip", "ambiguous"):
            return "skip"
        # Any other combination is unexpected.
        raise AssertionError(row)

    df["category"] = df.apply(get_category, axis=1)  # type: ignore

    def get_label(category: str) -> str:
        counts = len(df[df["category"] == category])  # type: ignore
        readable = {
            "honest, honest": "honest response to\n honest prompt",
            "honest, deceptive": "honest response to\n deceptive prompt",
            "deceptive": "deceptive response",
            "skip": "skipped",
        }
        return readable[category] + f"\n(n={counts})"

    shown_categories = ["honest, honest", "honest, deceptive", "deceptive"]

    plt.figure(figsize=figsize)
    sns.violinplot(
        data=df,
        x="category",
        hue="category",
        y="score",
        order=shown_categories,
        palette=[HONEST_COLOR, HONEST_COLOR, DECEPTIVE_COLOR],
        legend=False,
        formatter=get_label,
    )
    if title:
        plt.title(title)
    plt.ylabel("Probe deceptiveness score")
    plt.xlabel("")
    if not show_yticks:
        plt.yticks([])
    plt.tight_layout()
    plt.axhline(0, color="k", linewidth=1)
    save_plot_exp(exp, "ai_liar_violin.png", outfile)
    plt.close()


def reclassify_ambigious(exp: Experiment, scorer: PromptScorer):
    """Print AUROC and recall at 1% FPR after relabelling ambiguous examples.

    For each of a fixed set of datasets that is present in the experiment, the
    ambiguous examples are assigned a fixed label (honest or deceptive, depending on
    the dataset), merged with the honest and deceptive examples, and the two metrics
    from ``scorer`` are printed.
    """
    # Label that the ambiguous examples of each dataset are reassigned to.
    forced_label_by_dataset = {
        "insider_trading__upscale": Label.HONEST,
        "insider_trading_doubledown__upscale": Label.HONEST,
        "sandbagging_v2__wmdp_mmlu": Label.DECEPTIVE,
    }

    eval_scores = exp.get_eval_scores()
    for dataset_name, forced_label in forced_label_by_dataset.items():
        # Datasets that were not evaluated in this experiment are ignored.
        if dataset_name not in eval_scores:
            continue

        by_orig_label = eval_scores[dataset_name].get_scores_by_label()
        relabelled = by_orig_label[Label.AMBIGUOUS]
        relabelled.labels = [forced_label] * len(relabelled)
        merged = Scores.from_list(
            [by_orig_label[Label.HONEST], by_orig_label[Label.DECEPTIVE], relabelled]
        )

        auroc = scorer.auroc(merged)
        print(f"{dataset_name} reclassify ambigious auroc: {auroc:.3f}")
        recall = scorer.recall_at_fpr(merged, 0.01)
        print(f"{dataset_name} reclassify ambigious recall: {recall:.3f}")


def make_dataset_specific_plots(
    exp: Experiment,
    scorer: PromptScorer,
    outfile: Path | None = None,
):
    """Produce the dataset-specific plots for every such dataset present in ``exp``.

    Runs the AI Liar ROC and violin plots, the AI Audit violin plot and the
    roleplaying violin plot when the corresponding dataset has evaluation scores, then
    prints the metrics from ``reclassify_ambigious``.

    Args:
        exp: Experiment whose evaluation scores decide which plots are made.
        scorer: Scorer forwarded to every plot function.
        outfile: Forwarded unchanged to every plot function.
    """
    eval_scores = exp.get_eval_scores()
    if "ai_liar__original_without_answers" in eval_scores:
        ai_liar_roc_plot(exp, scorer, outfile)
        ai_liar_violin(exp, scorer, outfile)
    # Must match the dataset name that ai_audit_violin itself reads.
    if "ai_audit__a1_2_reasoning" in eval_scores:
        ai_audit_violin(exp, scorer, outfile)
    if "roleplaying__plain" in eval_scores:
        roleplaying_violin(exp, scorer, outfile)
    reclassify_ambigious(exp, scorer)
